package com.googlecode.gaal.vis;

import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

import com.googlecode.gaal.analysis.api.Filter;
import com.googlecode.gaal.analysis.api.VectorBuilder;
import com.googlecode.gaal.analysis.impl.IntervalVectorBuilder;
import com.googlecode.gaal.analysis.impl.ProperIntervalSetBuilder;
import com.googlecode.gaal.analysis.impl.RecursiveIntervalExtractor;
import com.googlecode.gaal.analysis.impl.SimpleContextFilter;
import com.googlecode.gaal.data.api.Corpus;
import com.googlecode.gaal.data.api.IntSequence;
import com.googlecode.gaal.data.api.Vector;
import com.googlecode.gaal.data.impl.ArraySequence;
import com.googlecode.gaal.data.impl.SparseVector;
import com.googlecode.gaal.data.impl.TreeMapCorpus;
import com.googlecode.gaal.preprocess.api.Tokenizer;
import com.googlecode.gaal.preprocess.impl.LowerCaseNormalizer;
import com.googlecode.gaal.preprocess.impl.MultidocumentRegexTokenizer;
import com.googlecode.gaal.suffix.api.EmbeddedSuffixTree.EmbeddedInterval;
import com.googlecode.gaal.suffix.api.IntervalTree.Interval;
import com.googlecode.gaal.suffix.api.LinearizedSuffixTree;
import com.googlecode.gaal.suffix.impl.LinearizedSuffixTreeImpl;
import com.googlecode.gaal.vis.api.VectorDrawing;
import com.googlecode.gaal.vis.impl.TikzConstants;

/**
 * Loads a source and a destination text file, builds a corpus, a linearized suffix tree and
 * interval vectors for each, and renders vectors, angles, documents and term matrices onto a
 * {@link VectorDrawing}.
 */
public class VectorVisualizer {

    /**
     * Token pattern: a word (letters/digits, optionally followed by an apostrophe and a lowercase
     * suffix) or one of the punctuation marks {@code . , ; ( )}.
     * <p>
     * NOTE(review): the character class contains lowercase umlauts but not Ä/Ö/Ü, so a word that
     * starts with an uppercase umlaut is split before lower-casing happens — confirm intended.
     */
    public static final String STRING_REGEX = "([A-Z0-9a-züäöß]+'?[a-z0-9]*)|[\\.,;\\(\\)]";

    /** Separator placed between rendered tokens. */
    private static final String DELIMITER = " ";

    /** Tokens passed to the corpus as separators. */
    private static final Set<String> SEPARATORS;
    static {
        SEPARATORS = new HashSet<String>();
        SEPARATORS.add(".");
        SEPARATORS.add(",");
    }

    /** Unit vector along dimension 0, used as the reference direction for angles. */
    private static final Vector X_AXIS = new SparseVector();
    static {
        X_AXIS.add(0, 1);
    }

    protected final Corpus<String> srcCorpus;
    protected final Corpus<String> dstCorpus;
    protected LinearizedSuffixTree srcLST;
    protected LinearizedSuffixTree dstLST;
    protected final ArraySequence srcSequence;
    protected final ArraySequence dstSequence;

    /** Interval vectors; {@code srcVectors.get(i)} belongs to {@code srcObjects.get(i)} (same for dst). */
    protected List<Vector> srcVectors;
    protected List<Vector> dstVectors;
    protected List<Interval> srcObjects;
    protected List<Interval> dstObjects;

    /**
     * Reads both files and precomputes corpora, suffix trees and interval vectors.
     *
     * @param srcFileName path of the source-language text
     * @param dstFileName path of the destination-language text
     * @param windowSize  window size passed to {@link RecursiveIntervalExtractor}
     * @throws FileNotFoundException if either file cannot be opened
     */
    public VectorVisualizer(String srcFileName, String dstFileName, int windowSize) throws FileNotFoundException {
        // NOTE(review): FileReader decodes with the platform default charset; STRING_REGEX expects
        // umlauts/ß, so files in another encoding may tokenize incorrectly — confirm.
        FileReader srcReader = new FileReader(srcFileName);
        FileReader dstReader = null;
        try {
            dstReader = new FileReader(dstFileName);
            Tokenizer<String> srcTokenizer = new MultidocumentRegexTokenizer(srcReader, STRING_REGEX,
                    new LowerCaseNormalizer());
            Tokenizer<String> dstTokenizer = new MultidocumentRegexTokenizer(dstReader, STRING_REGEX,
                    new LowerCaseNormalizer());
            srcCorpus = new TreeMapCorpus(srcTokenizer, SEPARATORS);
            dstCorpus = new TreeMapCorpus(dstTokenizer, SEPARATORS);
            srcSequence = srcCorpus.sequence();
            dstSequence = dstCorpus.sequence();
            srcLST = new LinearizedSuffixTreeImpl(srcSequence, srcCorpus.alphabetSize());
            dstLST = new LinearizedSuffixTreeImpl(dstSequence, dstCorpus.alphabetSize());
            srcVectors = new ArrayList<Vector>();
            dstVectors = new ArrayList<Vector>();
            srcObjects = new ArrayList<Interval>();
            dstObjects = new ArrayList<Interval>();
            VectorBuilder<Interval> vectorBuilder = new IntervalVectorBuilder<Interval>();
            Filter<EmbeddedInterval> contextFilter = new SimpleContextFilter(0, 0);
            Iterator<Interval> srcIntervalIterator = new RecursiveIntervalExtractor(srcLST, srcCorpus,
                    new ProperIntervalSetBuilder(), contextFilter, windowSize).iterator();
            Iterator<Interval> dstIntervalIterator = new RecursiveIntervalExtractor(dstLST, dstCorpus,
                    new ProperIntervalSetBuilder(), contextFilter, windowSize).iterator();
            vectorBuilder.buildVectors(srcIntervalIterator, srcVectors, srcObjects, srcCorpus, -1);
            vectorBuilder.buildVectors(dstIntervalIterator, dstVectors, dstObjects, dstCorpus, -1);
        } finally {
            // The corpora, suffix trees and vectors have all been built from the readers at this
            // point. This assumes the corpora do not read from them lazily afterwards.
            closeQuietly(srcReader);
            closeQuietly(dstReader);
        }
    }

    /**
     * Draws the "both...and" / "sowohl...als auch" vectors of both corpora, the angle between
     * them, and both corpora's documents with those repeats coloured.
     */
    public void visualizeVectors(VectorDrawing drawing) {
        List<String> labels = Arrays.asList(new String[] { "both$\\dots$and", "sowohl$\\dots$als auch" });
        Map<Integer, Vector> srcSelectedVectors = getVectors(srcCorpus, srcVectors, srcObjects, labels);
        Map<Integer, Vector> dstSelectedVectors = getVectors(dstCorpus, dstVectors, dstObjects, labels);

        drawVectors(drawing, srcSelectedVectors, labels);
        drawVectors(drawing, dstSelectedVectors, labels);

        // Selected vectors are paired by position (src label i vs dst label j), not by label index.
        drawAngles(drawing, new ArrayList<Vector>(srcSelectedVectors.values()), new ArrayList<Vector>(
                dstSelectedVectors.values()));

        Map<Integer, Integer> srcRepeatsMap = mapRepeats(srcCorpus, srcObjects, labels);
        Map<Integer, Integer> dstRepeatsMap = mapRepeats(dstCorpus, dstObjects, labels);

        List<String> srcDocuments = toDocumentList(srcCorpus, srcRepeatsMap);
        List<String> dstDocuments = toDocumentList(dstCorpus, dstRepeatsMap);
        drawDocuments(drawing, srcDocuments, dstDocuments);
    }

    /**
     * Draws the vectors whose interval labels appear in {@code labels}, then the documents with
     * those repeats coloured.
     * <p>
     * NOTE(review): this operates only on the destination corpus (dstCorpus/dstVectors/dstObjects)
     * — confirm this is intended.
     */
    public void visualizeRepeats(VectorDrawing drawing, List<String> labels) {
        Map<Integer, Vector> selectedVectors = getVectors(dstCorpus, dstVectors, dstObjects, labels);

        drawVectors(drawing, selectedVectors, labels);

        Map<Integer, Integer> repeatsMap = mapRepeats(dstCorpus, dstObjects, labels);

        List<String> documents = toDocumentList(dstCorpus, repeatsMap);
        drawDocuments(drawing, documents);
    }

    /** Draws the whole source corpus as one document. */
    public void visualizeCorpus(VectorDrawing drawing) {
        drawing.drawDocument(toDocument(srcCorpus), false);
    }

    /** Draws each source document separately, without highlighting. */
    public void visualizeDocuments(VectorDrawing drawing) {
        List<String> srcDocuments = toDocumentList(srcCorpus, Collections.<Integer, Integer> emptyMap());
        drawDocuments(drawing, srcDocuments);
    }

    /** Draws the whole source corpus next to the whole destination corpus. */
    public void visualizeParallelCorpus(VectorDrawing drawing) {
        drawing.drawDocument(toDocument(srcCorpus), toDocument(dstCorpus), false);
    }

    /** Draws source and destination documents pairwise, without highlighting. */
    public void visualizeParallelDocuments(VectorDrawing drawing) {
        List<String> srcDocuments = toDocumentList(srcCorpus, Collections.<Integer, Integer> emptyMap());
        List<String> dstDocuments = toDocumentList(dstCorpus, Collections.<Integer, Integer> emptyMap());
        drawDocuments(drawing, srcDocuments, dstDocuments);
    }

    /** Draws the source term-by-document count matrix (rows: sorted terms, columns: documents). */
    public void visualizeTermDocumentMatrix(VectorDrawing drawing) {
        List<Map<String, Integer>> termMapList = toTermMapList(srcCorpus);
        List<String> documentLabels = toDocumentLabels(termMapList.size());
        List<String> termLabels = toSortedTerms(termMapList);
        Map<String, Integer> termIndex = indexTerms(termLabels);
        int[][] values = new int[termLabels.size()][documentLabels.size()];
        for (int doc = 0; doc < termMapList.size(); doc++) {
            for (Map.Entry<String, Integer> entry : termMapList.get(doc).entrySet()) {
                values[termIndex.get(entry.getKey())][doc] = entry.getValue();
            }
        }
        drawing.drawMatrix(termLabels, documentLabels, values);
    }

    /**
     * Draws the source document-by-term count matrix (rows: documents, columns: sorted terms), i.e.
     * the transpose of {@link #visualizeTermDocumentMatrix(VectorDrawing)}.
     */
    public void visualizeWordContextMatrix(VectorDrawing drawing) {
        List<Map<String, Integer>> termMapList = toTermMapList(srcCorpus);
        List<String> documentLabels = toDocumentLabels(termMapList.size());
        List<String> termLabels = toSortedTerms(termMapList);
        Map<String, Integer> termIndex = indexTerms(termLabels);
        int[][] values = new int[documentLabels.size()][termLabels.size()];
        for (int doc = 0; doc < termMapList.size(); doc++) {
            for (Map.Entry<String, Integer> entry : termMapList.get(doc).entrySet()) {
                values[doc][termIndex.get(entry.getKey())] = entry.getValue();
            }
        }
        drawing.drawMatrix(documentLabels, termLabels, values);
    }

    /**
     * Renders an interval's label as text. An embedded interval is rendered recursively as
     * {@code <embedding interval>$\dots$<own label>}.
     */
    public static String toString(Interval interval, Corpus<String> corpus) {
        if (interval instanceof EmbeddedInterval) {
            EmbeddedInterval embeddedInterval = (EmbeddedInterval) interval;
            return String.format("%s$\\dots$%s", toString(embeddedInterval.getEmbeddingInterval(), corpus),
                    corpus.toString(embeddedInterval.label(), DELIMITER));
        }
        return corpus.toString(interval.label(), DELIMITER);
    }

    /**
     * Marks in {@code repeatsMap} every corpus position covered by the interval's occurrences
     * ({@code lcp} positions starting at each index) with {@code number}. For an embedded interval,
     * the embedding interval is mapped as well, at its embedding indices.
     *
     * @param indices occurrence start positions, or {@code null} to use {@code interval.indices()}
     */
    public static void mapInterval(Interval interval, IntSequence indices, Map<Integer, Integer> repeatsMap, int number) {
        if (indices == null) {
            indices = interval.indices();
        }
        int lcp = interval.lcp();
        for (int i = 0; i < indices.size(); i++) {
            for (int j = 0; j < lcp; j++) {
                repeatsMap.put(indices.get(i) + j, number);
            }
        }
        if (interval instanceof EmbeddedInterval) {
            EmbeddedInterval embeddedInterval = (EmbeddedInterval) interval;
            mapInterval(embeddedInterval.getEmbeddingInterval(), embeddedInterval.embeddingIndices(), repeatsMap,
                    number);
        }
    }

    /**
     * Filters occurrences that share an embedding index. Such occurrences are resolved pairwise,
     * and the one with the smaller index of each pair is dropped.
     *
     * @return {@code interval.indices()} itself if nothing was dropped, otherwise a new sequence
     *         of the surviving indices in their original order
     */
    public static IntSequence removeIntervening(EmbeddedInterval interval) {
        IntSequence indices = interval.indices();
        IntSequence embeddingIndices = interval.embeddingIndices();
        BitSet removedIndices = new BitSet(embeddingIndices.size());
        for (int i = 0; i < embeddingIndices.size(); i++) {
            if (!removedIndices.get(i)) {
                int index = indices.get(i);
                int embeddingIndex = embeddingIndices.get(i);
                for (int j = 0; j < embeddingIndices.size(); j++) {
                    if (i != j && !removedIndices.get(j) && embeddingIndices.get(j) == embeddingIndex) {
                        if (indices.get(j) > index) {
                            removedIndices.set(i);
                        } else {
                            removedIndices.set(j);
                        }
                        break;
                    }
                }
            }
        }
        if (removedIndices.cardinality() == 0) {
            return indices;
        }
        int[] newIndices = new int[indices.size() - removedIndices.cardinality()];
        int counter = 0;
        for (int i = 0; i < indices.size(); i++) {
            if (!removedIndices.get(i)) {
                newIndices[counter++] = indices.get(i);
            }
        }
        return new ArraySequence(newIndices);
    }

    /** Closes a reader, ignoring failures; {@code null} is allowed. */
    private static void closeQuietly(FileReader reader) {
        if (reader == null) {
            return;
        }
        try {
            reader.close();
        } catch (IOException ignored) {
            // The input has already been consumed; nothing useful can be done if close fails.
        }
    }

    /** Draws each selected vector using the label and style registered under its label index. */
    private static void drawVectors(VectorDrawing drawing, Map<Integer, Vector> selectedVectors, List<String> labels) {
        for (Map.Entry<Integer, Vector> entry : selectedVectors.entrySet()) {
            Vector vector = entry.getValue();
            int index = entry.getKey();
            drawing.drawVector(vector.get(0), vector.get(1), labels.get(index), TikzConstants.VECTOR_STYLES[index]);
        }
    }

    /**
     * Draws the angle between the i-th source and i-th destination vector, measured from the x
     * axis. Only as many pairs as the shorter list provides are drawn.
     */
    private static void drawAngles(VectorDrawing drawing, List<Vector> srcVectors, List<Vector> dstVectors) {
        int pairCount = Math.min(srcVectors.size(), dstVectors.size());
        for (int i = 0; i < pairCount; i++) {
            double srcToXAngle = angleToXAxis(srcVectors.get(i));
            double dstToXAngle = angleToXAxis(dstVectors.get(i));
            drawing.drawAngle(Math.min(srcToXAngle, dstToXAngle), Math.max(srcToXAngle, dstToXAngle), "$\\theta$");
        }
    }

    /** Angle in degrees between {@code vector} and the x axis. */
    private static double angleToXAxis(Vector vector) {
        // Assumes similarity() is a cosine. Clamp so that rounding just outside [-1, 1] does not
        // turn acos into NaN.
        double cosine = Math.max(-1.0, Math.min(1.0, X_AXIS.similarity(vector)));
        return Math.toDegrees(Math.acos(cosine));
    }

    /** Draws source and destination documents side by side, pairing them by position. */
    private static void drawDocuments(VectorDrawing drawing, List<String> srcDocuments, List<String> dstDocuments) {
        for (int i = 0; i < srcDocuments.size(); i++) {
            drawing.drawDocument(srcDocuments.get(i), dstDocuments.get(i), true);
        }
    }

    /** Draws each document on its own. */
    private static void drawDocuments(VectorDrawing drawing, List<String> documents) {
        for (int i = 0; i < documents.size(); i++) {
            drawing.drawDocument(documents.get(i), true);
        }
    }

    /**
     * Builds a map from corpus position to label index for every interval whose rendered label is
     * in {@code labels}.
     */
    private static Map<Integer, Integer> mapRepeats(Corpus<String> corpus, List<Interval> intervals, List<String> labels) {
        Map<Integer, Integer> repeatsMap = new HashMap<Integer, Integer>();
        for (int i = 0; i < intervals.size(); i++) {
            Interval interval = intervals.get(i);
            String label = toString(interval, corpus);
            int index = labels.indexOf(label);
            if (index != -1) {
                mapInterval(interval, null, repeatsMap, index);
            }
        }
        return repeatsMap;
    }

    /**
     * Renders the corpus as one string per document. Runs of positions mapped in
     * {@code repeatsMap} are wrapped in {@code {\color{...} ...}} groups. Every document gets
     * balanced braces: a group that crosses a document boundary is closed and reopened.
     */
    private static List<String> toDocumentList(Corpus<String> corpus, Map<Integer, Integer> repeatsMap) {
        IntSequence sequence = corpus.sequence();
        List<String> documents = new ArrayList<String>();
        StringBuilder sb = new StringBuilder();
        int prevRepeatNumber = -1;
        int docNumber = 0;
        boolean isFirst = true;
        boolean isBlockStart = false;
        for (int i = 0; i < sequence.size(); i++) {
            // Split documents before emitting any markup for position i, so that an opening
            // colour group ends up in the document the token belongs to.
            int docId = corpus.getDocumentId(i);
            if (docNumber != docId) {
                if (prevRepeatNumber != -1) {
                    sb.append("}");
                    prevRepeatNumber = -1;
                }
                docNumber = docId;
                documents.add(sb.toString());
                sb = new StringBuilder();
                isFirst = true;
            }
            Integer repeatNumber = repeatsMap.get(i);
            if (prevRepeatNumber != -1 && (repeatNumber == null || prevRepeatNumber != repeatNumber)) {
                sb.append("}");
            }
            if (repeatNumber != null && prevRepeatNumber != repeatNumber) {
                sb.append(String.format(" {\\color{%s}", TikzConstants.VECTOR_STYLES[repeatNumber].getColour()));
                isBlockStart = true;
            }
            if (repeatNumber != null) {
                prevRepeatNumber = repeatNumber;
            } else {
                prevRepeatNumber = -1;
            }
            int symbol = sequence.get(i);
            String token = corpus.toToken(symbol);
            if (!corpus.isSeparator(symbol)) {
                if (isFirst) {
                    isFirst = false;
                } else if (!token.equals(",") && !token.equals(".") && !isBlockStart) {
                    sb.append(DELIMITER);
                }
                sb.append(token);
            }
            isBlockStart = false;
        }
        // Close a group that runs to the very end of the corpus.
        if (prevRepeatNumber != -1) {
            sb.append("}");
        }
        documents.add(sb.toString());
        return documents;
    }

    /** Renders the whole corpus as a single space-delimited string, skipping separator symbols. */
    private static String toDocument(Corpus<String> corpus) {
        IntSequence sequence = corpus.sequence();
        StringBuilder sb = new StringBuilder();
        boolean isFirst = true;
        for (int i = 0; i < sequence.size(); i++) {
            int symbol = sequence.get(i);
            String token = corpus.toToken(symbol);
            if (!corpus.isSeparator(symbol)) {
                if (isFirst) {
                    isFirst = false;
                } else if (!token.equals(",") && !token.equals(".")) {
                    sb.append(DELIMITER);
                }
                sb.append(token);
            }
        }
        return sb.toString();
    }

    /**
     * Counts the terms of each document. Separator symbols and the tokens "," and "." are not
     * counted.
     *
     * @return one sorted term-to-count map per document, in document order (empty for an empty corpus)
     */
    private static List<Map<String, Integer>> toTermMapList(Corpus<String> corpus) {
        List<Map<String, Integer>> documentList = new ArrayList<Map<String, Integer>>();
        IntSequence sequence = corpus.sequence();
        Map<String, Integer> termMap = null;
        int docNumber = -1;
        for (int i = 0; i < sequence.size(); i++) {
            int docId = corpus.getDocumentId(i);
            if (termMap == null || docNumber != docId) {
                docNumber = docId;
                if (termMap != null) {
                    documentList.add(termMap);
                }
                termMap = new TreeMap<String, Integer>();
            }
            int symbol = sequence.get(i);
            if (corpus.isSeparator(symbol)) {
                continue;
            }
            String token = corpus.toToken(symbol);
            if (token.equals(",") || token.equals(".")) {
                continue;
            }
            Integer count = termMap.get(token);
            termMap.put(token, count == null ? 1 : count + 1);
        }
        if (termMap != null) {
            documentList.add(termMap);
        }
        return documentList;
    }

    /** Builds LaTeX labels {@code $d_0$ ... $d_{count-1}$}. */
    private static List<String> toDocumentLabels(int count) {
        List<String> documentLabels = new ArrayList<String>(count);
        for (int i = 0; i < count; i++) {
            documentLabels.add(String.format("$d_%d$", i));
        }
        return documentLabels;
    }

    /** Collects all terms occurring in any of the maps, in natural (sorted) order. */
    private static List<String> toSortedTerms(List<Map<String, Integer>> termMapList) {
        Set<String> termSet = new TreeSet<String>();
        for (Map<String, Integer> termMap : termMapList) {
            termSet.addAll(termMap.keySet());
        }
        return new ArrayList<String>(termSet);
    }

    /** Maps each term to its position in {@code terms}, so lookups avoid a linear indexOf. */
    private static Map<String, Integer> indexTerms(List<String> terms) {
        Map<String, Integer> index = new HashMap<String, Integer>();
        for (int i = 0; i < terms.size(); i++) {
            index.put(terms.get(i), i);
        }
        return index;
    }

    /**
     * Selects the vectors whose interval label (as rendered by {@link #toString(Interval, Corpus)})
     * is in {@code labels}.
     *
     * @return label index to vector, ordered by label index
     */
    private static Map<Integer, Vector> getVectors(Corpus<String> corpus, List<Vector> vectors,
            List<Interval> intervals, List<String> labels) {
        Map<Integer, Vector> vectorMap = new TreeMap<Integer, Vector>();
        for (int i = 0; i < vectors.size(); i++) {
            String label = toString(intervals.get(i), corpus);
            int index = labels.indexOf(label);
            if (index != -1) {
                vectorMap.put(index, vectors.get(i));
            }
        }
        return vectorMap;
    }
}
